"""
Weight loaders for different formats.

This module provides loaders for:
- SafeTensor format (for AMX quantized weights)
- GGUF format (for Llamafile quantized weights)
"""

from __future__ import annotations

import os
import numpy as np
import torch
from enum import IntEnum
from safetensors import safe_open
from gguf.gguf_reader import GGUFReader


class GGMLQuantizationType(IntEnum):
    """GGML quantization type enumeration"""

    F32 = 0
    F16 = 1
    Q4_0 = 2
    Q4_1 = 3
    Q5_0 = 6
    Q5_1 = 7
    Q8_0 = 8
    Q8_1 = 9
    Q2_K = 10
    Q3_K = 11
    Q4_K = 12
    Q5_K = 13
    Q6_K = 14
    Q8_K = 15
    IQ2_XXS = 16
    IQ2_XS = 17
    IQ3_XXS = 18
    IQ1_S = 19
    IQ4_NL = 20
    IQ3_S = 21
    IQ2_S = 22
    IQ4_XS = 23
    I8 = 24
    I16 = 25
    I32 = 26
    I64 = 27
    F64 = 28
    IQ1_M = 29
    BF16 = 30


def translate_name_to_gguf(name):
    """
    Translate PyTorch tensor name to GGUF format
    """
    name = name.replace("lm_head.", "output.")
    name = name.replace("model.embed_tokens.", "token_embd.")
    name = name.replace("model.norm.", "output_norm.")
    name = name.replace("model.layers.", "blk.")
    name = name.replace(".input_layernorm", ".attn_norm")
    name = name.replace(".mlp.down_proj", ".ffn_down")
    name = name.replace(".mlp.gate_proj", ".ffn_gate")
    name = name.replace(".mlp.up_proj", ".ffn_up")
    name = name.replace(".post_attention_layernorm", ".ffn_norm")
    name = name.replace(".self_attn.q_proj", ".attn_q")
    name = name.replace(".self_attn.k_proj", ".attn_k")
    name = name.replace(".self_attn.v_proj", ".attn_v")
    name = name.replace(".self_attn.o_proj", ".attn_output")
    name = name.replace(".self_attn.qkv_proj", ".attn_qkv")
    name = name.replace(".self_attn.kv_a_proj_with_mqa", ".attn_kv_a_mqa")
    name = name.replace(".self_attn.kv_a_layernorm", ".attn_kv_a_norm")
    name = name.replace(".self_attn.kv_b_proj", ".attn_kv_b")
    name = name.replace(".self_attn.q_a_proj", ".attn_q_a")
    name = name.replace(".self_attn.q_a_layernorm", ".attn_q_a_norm")
    name = name.replace(".self_attn.q_b_proj", ".attn_q_b")
    name = name.replace(".self_attn.q_norm", ".attn_q_norm")
    name = name.replace(".self_attn.k_norm", ".attn_k_norm")
    name = name.replace(".shared_expert.", ".shared_experts.")
    name = name.replace(".shared_expert_", ".shared_experts_")
    name = name.replace(".gate_up_proj.", ".up_proj")
    name = name.replace(".mlp.shared_experts.down_proj", ".ffn_down_shexp")
    name = name.replace(".mlp.gate.e_score_correction_bias", ".exp_probs_b.bias")
    name = name.replace(".mlp.gate", ".ffn_gate_inp")
    name = name.replace(".mlp.shared_experts.gate_proj", ".ffn_gate_shexp")
    name = name.replace(".mlp.shared_experts.up_proj", ".ffn_up_shexp")
    name = name.replace(".mlp.shared_experts_gate", ".ffn_gate_inp_shexp")
    name = name.replace(".mlp.experts", "")
    name = name.replace(".mlp.experts.ffn_down_exps", ".ffn_down_exps")
    name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps")
    name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps")
    name = name.replace(".block_sparse_moe.gate.", ".ffn_gate_inp.")
    name = name.replace(".block_sparse_moe.experts", "")
    name = name.replace(".feed_forward.experts", "")
    name = name.replace(".feed_forward.router", ".ffn_gate_inp")
    name = name.replace(".feed_forward.shared_experts.down_proj", ".ffn_down_shexp")
    name = name.replace(".feed_forward.shared_experts.gate_proj", ".ffn_gate_shexp")
    name = name.replace(".feed_forward.shared_experts.up_proj", ".ffn_up_shexp")
    return name


class SafeTensorLoader:
    """
    SafeTensor format loader for AMX quantized weights.

    Supports loading tensors from .safetensors files with NUMA-sharded expert weights.
    """

    tensor_file_map: dict
    tensor_type_map: dict
    file_handle_map: dict
    tensor_device_map: dict

    def __init__(self, file_path: str):
        self.__load_tensor_file_map(file_path)

    def __load_tensor_file_map(self, file_path: str):
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Path not found: {file_path}")
        if os.path.isfile(file_path):
            folder_path = os.path.dirname(file_path)
        else:
            folder_path = file_path
        self.file_handle_map = {}
        self.tensor_file_map = {}
        self.tensor_type_map = {}
        self.tensor_device_map = {}

        found_safetensor = False
        for root, _, files in os.walk(folder_path):
            files = sorted(files)
            for file in files:
                if file.endswith(".safetensors"):
                    found_safetensor = True
                    file_path = os.path.join(root, file)
                    if file not in self.file_handle_map:
                        try:
                            handle = safe_open(file_path, framework="pt")
                            self.file_handle_map[file] = handle
                        except Exception as e:
                            print(f"Error opening Safetensor file {file_path}: {e}")
                            continue

                    f = self.file_handle_map.get(file)
                    if f is None:
                        continue
                    try:
                        for key in f.keys():
                            self.tensor_file_map[key] = file
                    except Exception as e:
                        print(f"Error reading Safetensor file {file_path}: {e}")

        if not found_safetensor:
            raise FileNotFoundError(f"No Safetensor files found in {folder_path}")

    def load_tensor(self, key: str, device: str = "cpu"):
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key)
        return tensor.to(device)

    def close_all_handles(self):
        for handle in self.file_handle_map.values():
            handle.close()
        self.file_handle_map.clear()

    def load_experts(self, base_key: str, device: str = "cpu"):
        """
        Load expert weights from SafeTensor files.

        Expected format:
        - blk.{layer_index}.ffn_[up, down, gate]_exps.{expert_id}.numa.{numa_id}.weight
        - blk.{layer_index}.ffn_[up, down, gate]_exps.{expert_id}.numa.{numa_id}.scale

        Args:
            base_key: Base key like "blk.{layer_index}"
            device: Target device for tensors

        Returns:
            Dictionary with keys: up, gate, down, up_scale, gate_scale, down_scale
            Each value is a list of lists: [numa_id][expert_id] -> numpy array
        """
        up_base_key = f"{base_key}.ffn_up_exps"
        gate_base_key = f"{base_key}.ffn_gate_exps"
        down_base_key = f"{base_key}.ffn_down_exps"
        max_numa_id = -1
        max_experts_count = -1
        while self.has_tensor(f"{up_base_key}.{max_experts_count+1}.numa.{0}.weight"):
            max_experts_count += 1
        if max_experts_count == 0:
            raise ValueError(f"No experts found for key {base_key}")
        while self.has_tensor(f"{up_base_key}.{0}.numa.{max_numa_id+1}.weight"):
            max_numa_id += 1
        # Initialize empty lists to store tensors for each projection type
        up_weights = [[] for _ in range(max_numa_id + 1)]
        gate_weights = [[] for _ in range(max_numa_id + 1)]
        down_weights = [[] for _ in range(max_numa_id + 1)]
        up_scales = [[] for _ in range(max_numa_id + 1)]
        gate_scales = [[] for _ in range(max_numa_id + 1)]
        down_scales = [[] for _ in range(max_numa_id + 1)]
        for numa_id in range(max_numa_id + 1):
            for expert_id in range(max_experts_count + 1):
                up_key = f"{up_base_key}.{expert_id}.numa.{numa_id}.weight"
                gate_key = f"{gate_base_key}.{expert_id}.numa.{numa_id}.weight"
                down_key = f"{down_base_key}.{expert_id}.numa.{numa_id}.weight"
                up_scale_key = f"{up_base_key}.{expert_id}.numa.{numa_id}.scale"
                gate_scale_key = f"{gate_base_key}.{expert_id}.numa.{numa_id}.scale"
                down_scale_key = f"{down_base_key}.{expert_id}.numa.{numa_id}.scale"
                # make sure contiguous
                up_tensor = self.load_tensor(up_key, device).numpy()
                gate_tensor = self.load_tensor(gate_key, device).numpy()
                down_tensor = self.load_tensor(down_key, device).numpy()
                up_scale_tensor = self.load_tensor(up_scale_key, device).numpy()
                gate_scale_tensor = self.load_tensor(gate_scale_key, device).numpy()
                down_scale_tensor = self.load_tensor(down_scale_key, device).numpy()

                up_weights[numa_id].append(up_tensor)
                gate_weights[numa_id].append(gate_tensor)
                down_weights[numa_id].append(down_tensor)
                up_scales[numa_id].append(up_scale_tensor)
                gate_scales[numa_id].append(gate_scale_tensor)
                down_scales[numa_id].append(down_scale_tensor)
        return {
            "up": up_weights,
            "gate": gate_weights,
            "down": down_weights,
            "up_scale": up_scales,
            "gate_scale": gate_scales,
            "down_scale": down_scales,
        }

    def has_tensor(self, name: str):
        return name in self.tensor_file_map


class GGUFLoader:
    """
    GGUF format loader using the official gguf library (gguf.gguf_reader.GGUFReader)

    This is a cleaner implementation compared to manual binary parsing.
    """

    def __init__(self, gguf_path: str):
        """
        Initialize GGUF loader from a file or directory

        Args:
            gguf_path: Path to a single GGUF file or a directory containing GGUF files
        """
        if not os.path.exists(gguf_path):
            raise FileNotFoundError(f"GGUF path not found: {gguf_path}")

        self.tensor_info = {}
        self.metadata = {}
        self.tensor_file_map = {}
        self.file_data_map = {}

        if os.path.isfile(gguf_path) and gguf_path.endswith(".gguf"):
            print(f"\n[GGUFLoader] Loading single GGUF file : {os.path.basename(gguf_path)}")
            self._load_single_file(gguf_path)
        elif os.path.isdir(gguf_path):
            print(f"\n[GGUFLoader] Loading GGUF files from directory: {gguf_path}")
            self._load_directory(gguf_path)
        else:
            raise ValueError(f"Path must be a .gguf file or a directory: {gguf_path}")

        print(f"[GGUFLoader] Summary:")
        print(f"  Files loaded: {len(self.file_data_map)}")
        print(f"  Total tensors: {len(self.tensor_info)}")
        print(f"  Metadata keys: {len(self.metadata)}")
        tensors = ["blk.0.ffn_up_exps.weight", "blk.0.ffn_gate_exps.weight", "blk.0.ffn_down_exps.weight"]
        for key in tensors:
            if key in self.tensor_info:
                info = self.tensor_info[key]
                print(f" {'.'.join(key.split('.')[2:-1])}, Dtype: {info['dtype'].name}")

    def _load_single_file(self, file_path: str):
        """Load a single GGUF file"""
        reader = GGUFReader(file_path)

        for key, field in reader.fields.items():
            value = field.parts[field.data[0]]
            if isinstance(value, bytes):
                value = value.decode("utf-8")
            elif isinstance(value, np.ndarray) and value.dtype == np.uint8:
                try:
                    value = bytes(value).decode("utf-8")
                except:
                    pass
            self.metadata[key] = value

        for tensor in reader.tensors:
            self.tensor_info[tensor.name] = {
                "shape": list(reversed(tensor.shape)),  # Reverse to match PyTorch order
                "dtype": tensor.tensor_type,
                "offset": tensor.data_offset,
                "n_elements": tensor.n_elements,
            }
            self.tensor_file_map[tensor.name] = file_path

        self.file_data_map[file_path] = np.memmap(file_path, mode="r")

    def _load_directory(self, dir_path: str):
        """Load all GGUF files from a directory (non-recursive)"""
        found_gguf = False

        for file in sorted(os.listdir(dir_path)):
            if file.endswith(".gguf"):
                found_gguf = True
                file_path = os.path.join(dir_path, file)
                print(f"  Loading: {file}")

                reader = GGUFReader(file_path)

                for key, field in reader.fields.items():
                    value = field.parts[field.data[0]]
                    if isinstance(value, bytes):
                        value = value.decode("utf-8")
                    elif isinstance(value, np.ndarray) and value.dtype == np.uint8:
                        try:
                            value = bytes(value).decode("utf-8")
                        except:
                            pass
                    self.metadata[key] = value

                for tensor in reader.tensors:
                    self.tensor_info[tensor.name] = {
                        "shape": list(reversed(tensor.shape)),
                        "dtype": tensor.tensor_type,
                        "offset": tensor.data_offset,
                        "n_elements": tensor.n_elements,
                    }
                    self.tensor_file_map[tensor.name] = file_path

                self.file_data_map[file_path] = np.memmap(file_path, mode="r")

        if not found_gguf:
            raise FileNotFoundError(f"No .gguf files found in directory: {dir_path}")

    def get_model_config(self, layer_idx: int = 0):
        """
        Extract model configuration from GGUF metadata and tensor shapes.

        Args:
            layer_idx: Layer index to inspect (default: 0)

        Returns:
            dict with keys: num_experts, num_experts_per_tok, hidden_size, moe_intermediate_size
        """
        config = {}

        arch = self.metadata.get("general.architecture", "unknown")

        num_experts = None
        for key_suffix in [
            "expert_count",
            "expert.count",
            "moe.expert_count",
            "expert_feed_forward_length",
        ]:
            key = f"{arch}.{key_suffix}"
            if key in self.metadata:
                val = self.metadata[key]
                num_experts = int(val[0]) if isinstance(val, (list, np.ndarray)) else int(val)
                break

        num_experts_per_tok = None
        for key_suffix in [
            "expert_used_count",
            "expert.used_count",
            "moe.num_experts_per_tok",
        ]:
            key = f"{arch}.{key_suffix}"
            if key in self.metadata:
                val = self.metadata[key]
                num_experts_per_tok = int(val[0]) if isinstance(val, (list, np.ndarray)) else int(val)
                break

        hidden_size = None
        for key_suffix in [
            "embedding_length",
            "embed_length",
            "hidden_size",
        ]:
            key = f"{arch}.{key_suffix}"
            if key in self.metadata:
                val = self.metadata[key]
                hidden_size = int(val[0]) if isinstance(val, (list, np.ndarray)) else int(val)
                break

        moe_intermediate_size = None
        for key_suffix in [
            "expert_feed_forward_length",
            "feed_forward_length",
            "ffn_length",
            "intermediate_size",
        ]:
            key = f"{arch}.{key_suffix}"
            if key in self.metadata:
                val = self.metadata[key]
                moe_intermediate_size = int(val[0]) if isinstance(val, (list, np.ndarray)) else int(val)
                break

        if any(v is None for v in [num_experts, hidden_size, moe_intermediate_size]):

            base_key = f"blk.{layer_idx}.ffn_gate_exps.weight"
            if base_key in self.tensor_info:
                gate_shape = self.tensor_info[base_key]["shape"]
                print(f"  Found tensor '{base_key}' with shape: {gate_shape}")

                if len(gate_shape) >= 3:
                    if num_experts is None:
                        num_experts = int(gate_shape[0])
                    if moe_intermediate_size is None:
                        moe_intermediate_size = int(gate_shape[1])
                    if hidden_size is None:
                        hidden_size = int(gate_shape[2])

        config = {
            "num_experts": num_experts,
            "num_experts_per_tok": num_experts_per_tok,
            "hidden_size": hidden_size,
            "moe_intermediate_size": moe_intermediate_size,
        }

        return config

    def print_metadata(self, filter_keywords=None):
        """
        Print GGUF file metadata for debugging.

        Args:
            filter_keywords: Optional list of keywords to filter metadata keys
        """
        print(f"\n[GGUFLoader] GGUF Metadata:")
        print(f"  Total metadata entries: {len(self.metadata)}")

        if filter_keywords:
            filtered = {
                k: v for k, v in self.metadata.items() if any(kw.lower() in k.lower() for kw in filter_keywords)
            }
            for k, v in sorted(filtered.items()):
                print(f"  {k}: {v}")
        else:
            for k, v in sorted(self.metadata.items()):
                print(f"  {k}: {v}")

    def has_tensor(self, name: str):
        """Check if tensor exists"""
        name = translate_name_to_gguf(name)
        return name in self.tensor_info

    def get_ggml_type(self, name: str):
        """Get GGML type of a tensor"""
        name = translate_name_to_gguf(name)
        if name not in self.tensor_info:
            raise KeyError(f"Tensor '{name}' not found in GGUF files")
        return self.tensor_info[name]["dtype"]

    def get_undequanted_tensor_and_ggml_type(self, name: str):
        """
        Get tensor data and its GGML type without dequantizing

        Args:
            name: Tensor name (in PyTorch format, will be translated to GGUF format)

        Returns:
            (data, ggml_type): Tuple of tensor data and GGML quantization type
        """
        name = translate_name_to_gguf(name)

        if name not in self.tensor_info:
            raise KeyError(f"Tensor '{name}' not found in GGUF files")

        info = self.tensor_info[name]
        file_path = self.tensor_file_map[name]
        mmap_data = self.file_data_map[file_path]

        offset = info["offset"]
        n_elements = info["n_elements"]
        ggml_type = info["dtype"]

        GGML_QUANT_SIZES = {
            GGMLQuantizationType.F32: (1, 4),
            GGMLQuantizationType.F16: (1, 2),
            GGMLQuantizationType.BF16: (1, 2),
            GGMLQuantizationType.Q4_0: (32, 2 + 16),
            GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
            GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
            GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
            GGMLQuantizationType.Q8_0: (32, 2 + 32),
            GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
            GGMLQuantizationType.Q2_K: (256, 2 + 2 + 256 // 16 + 256 // 4),
            GGMLQuantizationType.Q3_K: (256, 2 + 256 // 4 + 256 // 8 + 12),
            GGMLQuantizationType.Q4_K: (256, 2 + 2 + 256 // 2 + 12),
            GGMLQuantizationType.Q5_K: (256, 2 + 2 + 256 // 2 + 256 // 8 + 12),
            GGMLQuantizationType.Q6_K: (256, 2 + 256 // 2 + 256 // 4 + 256 // 16),
            GGMLQuantizationType.Q8_K: (256, 4 + 256 + 256 // 8),
            GGMLQuantizationType.IQ2_XXS: (256, 2 + 256 // 4),
            GGMLQuantizationType.IQ2_XS: (256, 2 + 256 // 4 + 256 // 32),
            GGMLQuantizationType.IQ3_XXS: (256, 2 + 256 // 4 + 256 // 8),
            GGMLQuantizationType.IQ1_S: (256, 2 + 256 // 8 + 256 // 16),
            GGMLQuantizationType.IQ4_NL: (32, 2 + 16),
            GGMLQuantizationType.IQ3_S: (256, 2 + 256 // 4 + 256 // 8 + 256 // 32 + 4),
            GGMLQuantizationType.IQ2_S: (256, 2 + 256 // 4 + 256 // 16),
            GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + 256 // 2 + 256 // 64),
            GGMLQuantizationType.I8: (1, 1),
            GGMLQuantizationType.I16: (1, 2),
            GGMLQuantizationType.I32: (1, 4),
            GGMLQuantizationType.I64: (1, 8),
            GGMLQuantizationType.F64: (1, 8),
            GGMLQuantizationType.IQ1_M: (256, 256 // 8 + 256 // 16 + 256 // 32),
        }

        block_size, type_size = GGML_QUANT_SIZES[ggml_type]
        n_bytes = n_elements * type_size // block_size

        data_bytes = mmap_data[offset : offset + n_bytes]
        data = torch.from_numpy(np.frombuffer(data_bytes, dtype=np.uint8).copy())

        return data, ggml_type
